#!/usr/bin/python
# -*- coding: utf-8 -*-
from __future__ import division
import sys, argparse
import csv
import time as t
import datetime as d
import numpy as np
import pandas as pd
from sklearn.cross_validation import StratifiedShuffleSplit
'''
Convert CSV-format feature data into the sparse "label index:value" text
format used for GBDT training.
'''

def parse_args():
    """Parse the command-line arguments of the conversion script.

    When the script is started without any argument, the usage/help text is
    printed and the process exits with status 0 (``sys.argv`` is left
    untouched).

    Returns:
        dict: the raw (string) argument values keyed by ``'train_feature'``,
        ``'valid_ratio'`` and ``'with_label'``. Callers convert them to the
        types they need.
    """
    parser = argparse.ArgumentParser(
        description='Convert a feature CSV into sparse "label index:value" '
                    'text files for GBDT training.')
    parser.add_argument(
        'train_feature',
        help='path of the input feature CSV; outputs are written next to it '
             'as <path>.train.svm and <path>.valid.svm')
    parser.add_argument(
        'valid_ratio',
        help='fraction of rows held out as the validation set; a value '
             'that is not greater than 0 disables the split')
    parser.add_argument(
        'with_label',
        help='integer flag; a value greater than 0 means the first remaining '
             'column is written as the label of each line')

    if len(sys.argv) == 1:
        # No arguments at all: show the help text instead of argparse's terse
        # "too few arguments" error, without mutating the global sys.argv.
        parser.print_help()
        sys.exit(0)

    return vars(parser.parse_args())

# In the data, pos/neg = 4:1

#TODO shuffle the original train

# The input is the complete feature file
args = parse_args()

train_feature = args['train_feature']
valid_ratio = float(args['valid_ratio'])
with_label = int(args['with_label']) > 0
print 'with_label:\t', with_label

train = pd.read_csv(train_feature)

train.drop(['enrollment_id'], axis=1, inplace=True)

trainset = train.values
validset = []

if valid_ratio > 0.0:
    sss = StratifiedShuffleSplit(train.label, test_size=valid_ratio, random_state=1234)
    for train_index, test_index in sss:
        break
    trainset = train.values[train_index]
    validset = train.values[test_index]

output_train_path = "%s.train.svm" % train_feature
output_valid_path = "%s.valid.svm" % train_feature

print 'output trainset to ', output_train_path
print 'output validset to ', output_valid_path

def rcd2line(rcd):
    """Format one record as a space-separated line of ``index:value`` pairs.

    When the module-level ``with_label`` flag is true, the first element of
    ``rcd`` is written as an integer label and the remaining elements are the
    features. Otherwise every element is a feature and no label is written.
    Feature indices start at 0 and values are rendered with ``%f``.

    Args:
        rcd: indexable sequence of numeric values (one row of the data).

    Returns:
        str: the formatted line, without a trailing newline.
    """
    if with_label:
        fields = [str(int(rcd[0]))]
        features = rcd[1:]
    else:
        fields = []
        features = rcd
    fields.extend('%d:%f' % (idx, val) for idx, val in enumerate(features))
    return ' '.join(fields)

# Dump both splits, one formatted record per line. The validation file is
# still created (and left empty) when no validation rows were held out.
with open(output_train_path, 'w') as train_out, open(output_valid_path, 'w') as valid_out:
    train_out.writelines(rcd2line(row) + '\n' for row in trainset)
    valid_out.writelines(rcd2line(row) + '\n' for row in validset)
